# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This script, forward_pass_logits_checker.py, verifies the correctness of logits generated by the MaxText or HuggingFace
# implementation of a model by comparing them against a reference "golden" logits file for a set of input prompts.
#
# It supports multiple models and expects a reference file in JSON Lines (.jsonl) or pickle (.pickle or .pkl) format,
# specified via the --golden_logits_path argument. If not provided, it defaults to:
#   src/MaxText/test_assets/golden_data_<model_name>.jsonl
# For example:
#   src/MaxText/test_assets/golden_data_llama2-7b.jsonl
#
# Each line in the golden .jsonl file should be a dictionary with the following keys:
#   1. prompt: a string prompt, e.g., "I love to"
#   2. tokens: the token IDs resulting from tokenizing the prompt
#   3. logits: the expected (golden) logits output by the model for the given prompt
#
# The script runs a forward pass using the MaxText implementation and compares the resulting logits against the golden
# ones, asserting that they match within a tolerance.
#
# Utilities and Colab notebooks used to generate the golden logits are available in src/MaxText/scratch_code — for example:
#   src/MaxText/scratch_code/golden_llama2-7b_export.ipynb

"""Check if the logits generated by a model's src/MaxText/HF implementation matches golden logits for the same inputs"""

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp

import torch.nn.functional as F
import torch

from google.cloud import storage

from transformers import AutoModelForCausalLM, AutoTokenizer

from MaxText.utils.ckpt_conversion.utils.hf_utils import (
    convert_jax_weight_to_torch,
)
from MaxText import max_logging
from MaxText import maxtext_utils
from MaxText import pyconfig
from MaxText.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN
from MaxText.globals import MAXTEXT_TEST_ASSETS_ROOT
from MaxText.layers import models
from MaxText.layers import quantizations


def upload_blob(bucket_name, source_file_name, destination_blob_name):
  """Upload a local file to a Google Cloud Storage bucket.

  Args:
    bucket_name: Name of the destination GCS bucket.
    source_file_name: Path of the local file to upload.
    destination_blob_name: Object name to create inside the bucket.
  """
  client = storage.Client()
  target_blob = client.get_bucket(bucket_name).blob(destination_blob_name)
  target_blob.upload_from_filename(source_file_name)


def get_top_k_tokens_scores(logits_tensor, tokenizer_instance, k=10, description=""):
  """Get the top-k tokens and their scores from a given logits tensor.

  Only the first batch row of ``logits_tensor`` is inspected.

  Args:
    logits_tensor: Torch tensor of logits; ``logits_tensor[0]`` is ranked.
    tokenizer_instance: Tokenizer used to decode token IDs into text.
    k: Number of top-scoring tokens to report.
    description: Label included in the logged table header.

  Returns:
    A list of ``k`` dicts with keys "id", "token" and "score", ordered by
    descending score.
  """
  max_logging.log(f"\n--- {description} top {k} tokens ---")
  # The original built two identical lists (collected_tokens and tokens);
  # a single list is sufficient for both logging and the return value.
  tokens = []
  topk_results = torch.topk(logits_tensor[0], k=k)
  for i in range(k):
    tok_id = topk_results.indices[i].item()
    score = topk_results.values[i].item()
    tok = tokenizer_instance.decode(tok_id)
    tokens.append({"id": int(tok_id), "token": tok.strip(), "score": float(score)})

  # Render the results as a fixed-width table for readability in logs.
  table_str = f"| {'Token ID':<10} | {'Token':<20} | {'Score':<10} |\n"
  table_str += f"|{'-'*12}|{'-'*22}|{'-'*12}|\n"
  for d in tokens:
    table_str += f"| {d['id']:<10} | {d['token']:<20} | {d['score']:<10.4f} |\n"
  max_logging.log(table_str)
  return tokens


def compare_top_tokens(converted_tokens, golden_tokens):
  """
  Compares two lists of top tokens, logs similarity metrics, and returns them.

  Args:
      converted_tokens: top tokens from the converted model.
      golden_tokens:  top tokens from the golden model.

  Returns:
      A dict with keys "overlap_count" (a "matched/total" string),
      "jaccard_similarity" (float in [0, 1]) and
      "rank_agreement_percentage" (float in [0, 100]). Previously the metrics
      were only logged and then discarded; returning them lets callers assert
      on them programmatically. Existing callers that ignore the return value
      are unaffected.
  """
  # Extract the sets of token IDs for comparison
  converted_ids = {token["id"] for token in converted_tokens}
  golden_ids = {token["id"] for token in golden_tokens}

  # --- Metric 1: Overlap Count & Jaccard Similarity ---
  intersection = converted_ids.intersection(golden_ids)
  union = converted_ids.union(golden_ids)

  overlap_count = len(intersection)
  jaccard_similarity = overlap_count / len(union) if union else 0.0

  # --- Metric 2: Rank Agreement ---
  # Counts positions where both lists agree on the exact token at that rank.
  rank_matches = 0
  min_len = min(len(converted_tokens), len(golden_tokens))
  for i in range(min_len):
    if converted_tokens[i]["id"] == golden_tokens[i]["id"]:
      rank_matches += 1

  rank_agreement = (rank_matches / min_len) * 100 if min_len > 0 else 0.0

  metrics = {
      "overlap_count": f"{overlap_count}/{min_len}",
      "jaccard_similarity": jaccard_similarity,
      "rank_agreement_percentage": rank_agreement,
  }

  max_logging.log("\n--- Similarity Metrics of Top Tokens ---")
  table_str = f"| {'Metric':<30} | {'Value':<20} |\n"
  table_str += f"|{'-'*32}|{'-'*22}|\n"
  for key, value in metrics.items():
    table_str += f"| {key:<30} | {str(value):<20} |\n"
  max_logging.log(table_str)
  return metrics


def check_kl_divergence(model_logits, golden_logits, atol=0.02):
  """
  Calculates KL divergence D_KL(P_golden || Q_model) over a batch of sequences.

  The comparison is restricted to the overlapping vocabulary slice when the
  two models have different vocab sizes.

  Args:
      model_logits: Logits from the converted model (Batch, SeqLen, VocabSize).
      golden_logits: Logits from the golden model (Batch, SeqLen, VocabSize).
      atol: Maximum allowed KL divergence for any single token position.

  Raises:
      AssertionError: If the max per-token KL divergence exceeds ``atol``.
  """
  # 1. Select the relevant vocabulary slice from the logits.
  token_size = min(model_logits.shape[2], golden_logits.shape[2])
  model_logits_sliced = model_logits[..., :token_size]
  golden_logits_sliced = golden_logits[..., :token_size]

  # 2. Flatten (batch, seq) into a single dimension of token distributions.
  # Use reshape (not view): view() raises on non-contiguous tensors, which the
  # vocab slicing above can produce; reshape() handles both cases identically.
  b, s, v = model_logits_sliced.shape
  model_logits_reshaped = model_logits_sliced.reshape(b * s, v)
  golden_logits_reshaped = golden_logits_sliced.reshape(b * s, v)

  # 3. Get the probability distributions.
  golden_probabilities = F.softmax(golden_logits_reshaped, dim=-1)
  model_log_probabilities = F.log_softmax(model_logits_reshaped, dim=-1)

  # 4. Calculate avg KL divergence for all token distributions.
  # use 'batchmean'; the sum of the KL divergences for each token in the batch
  # and then divides by the number of tokens (b * s)
  kl_div_value = F.kl_div(
      input=model_log_probabilities,
      target=golden_probabilities,
      reduction="batchmean",  # Use 'batchmean' for the average KL per token.
      log_target=False,
  )

  max_logging.log(f"\nAverage KL divergence per token (D_KL(P_golden || Q_model)): {kl_div_value.item():.6f}")

  # To find the max KL divergence for any single token in the set
  # use reduction='none'.
  kl_divs_per_token = F.kl_div(
      input=model_log_probabilities, target=golden_probabilities, reduction="none", log_target=False
  ).sum(
      dim=-1
  )  # Sum over the vocab dim to get a single KL value per token

  max_kl_div = kl_divs_per_token.max()
  max_logging.log(f"\nMax KL divergence for a single token in the set: {max_kl_div.item():.6f}")

  assert max_kl_div < atol, f"KL divergence values {max_kl_div.item():.6f} exceed the threshold {atol}"


def get_data(golden_data_point, config):
  """Build MaxText forward-pass inputs and golden logits from one data point.

  Args:
    golden_data_point: Dict with "tokens", "logits", "prompt" (optionally
      "formatted_prompt" and, for multimodal runs, "pixel_values").
    config: MaxText config providing batch size, max_target_length, etc.

  Returns:
    Tuple of (ids, decoder_segment_ids, decoder_positions, logits, seq_len,
    pixel_values); pixel_values is None when use_multimodal is off.

  Raises:
    ValueError: If the golden sequence exceeds config.max_target_length.
  """
  batch = config.global_batch_size_to_train_on
  max_logging.log(f"config.global_batch_size_to_train_on={batch}")
  if config.use_multimodal:
    assert "pixel_values" in golden_data_point, "no image found in golden data while use_multimodal=True"
    pixel_values = np.asarray(golden_data_point["pixel_values"], dtype=np.float32)
    max_logging.log(f"pixel_values.shape = {pixel_values.shape}")
    model_prefix = config.model_name.split("-")[0]
    if model_prefix == "gemma3":
      # gemma3 images are rearranged channels-last.
      pixel_values = np.transpose(pixel_values, (1, 2, 0))
    elif model_prefix == "llama4":
      # llama4 gets an extra leading axis.
      pixel_values = pixel_values[None, :]
    # Replicate the single image across the batch dimension.
    pixel_values = np.stack([pixel_values] * batch)
  else:
    pixel_values = None

  original_ids = np.asarray(golden_data_point["tokens"], dtype=np.int32)
  seq_len = len(original_ids)

  if seq_len > config.max_target_length:
    raise ValueError(
        f"Golden data sequence length ({seq_len}) is greater than max_target_length ({config.max_target_length})"
    )

  batch_shape = (batch, config.max_target_length)

  # Pad ids to max_target_length. MaxText expects 0 for padding.
  pad_width = (0, config.max_target_length - seq_len)
  padded_ids = np.pad(original_ids, pad_width, "constant", constant_values=0)
  ids = np.tile(padded_ids, (batch, 1))

  logits = np.asarray(golden_data_point["logits"], dtype=np.float32)
  prompt = (
      golden_data_point["formatted_prompt"] if "formatted_prompt" in golden_data_point else golden_data_point["prompt"]
  )
  max_logging.log(f' prompt="{prompt}" raw ids={original_ids}, logits.shape = {logits.shape}')

  # Mark the real tokens as active; padded positions stay 0.
  decoder_segment_ids = np.zeros(batch_shape, dtype=np.int32)
  decoder_segment_ids[:, :seq_len] = DECODING_ACTIVE_SEQUENCE_INDICATOR
  decoder_positions = np.tile(np.arange(config.max_target_length, dtype=np.int32), (batch, 1))
  return ids, decoder_segment_ids, decoder_positions, logits, seq_len, pixel_values


def main(config, test_args):  # pylint: disable=W0621
  """Compare forward-pass logits of a model against reference logits.

  When ``test_args.run_hf_model`` is False, runs a MaxText forward pass over
  the prompts stored in a golden data file (.jsonl or .pickle/.pkl) and
  compares the resulting logits against the stored golden logits. When True,
  runs both a MaxText and a HuggingFace forward pass on a fixed set of prompts
  and compares the two on-the-fly.

  Args:
    config: MaxText pyconfig configuration for the model under test.
    test_args: Parsed CLI arguments (tolerances, golden/output paths, mode).

  Raises:
    ValueError: If the golden file suffix is unsupported, or run_hf_model is
      set without hf_model_path.
    AssertionError: If logits differ beyond the configured tolerances.
  """
  if not test_args.run_hf_model:
    # Compare the MaxText model with pre-loaded golden logits.
    max_logging.log("Initializing MaxText model")
    init_rng = jax.random.PRNGKey(config.init_weights_seed)
    init_rng, rng1 = jax.random.split(init_rng)
    devices_array = maxtext_utils.create_device_mesh(config)
    mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
    quant = quantizations.configure_quantization(config)
    model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
    state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None)

    if test_args.golden_logits_path == "":
      input_golden_data_path = os.path.join(MAXTEXT_TEST_ASSETS_ROOT, f"golden_data_{config.model_name}.jsonl")
    else:
      input_golden_data_path = test_args.golden_logits_path
    input_golden_data_path = Path(input_golden_data_path)
    if input_golden_data_path.suffix == ".jsonl":
      max_logging.log("loading hf goldens from jsonl file")
      import jsonlines  # pylint: disable=import-outside-toplevel

      with jsonlines.open(input_golden_data_path, "r") as f:
        golden_data = list(f)
    elif input_golden_data_path.suffix in [".pickle", ".pkl"]:
      max_logging.log("loading hf goldens from pickle file")
      import pickle  # pylint: disable=import-outside-toplevel

      with open(input_golden_data_path, "rb") as f:
        golden_data = pickle.load(f)
    else:
      raise ValueError("golden_logits_path must end with .jsonl or .pickle/.pkl")
    max_logging.log(f"loaded {len(golden_data)} golden data points")
    all_data_to_save = []
    for golden_data_index, golden_data_point in enumerate(golden_data):
      max_logging.log(f"\n--- Comparing forward pass for golden data index: {golden_data_index} ---")
      ids, decoder_segment_ids, decoder_positions, golden_logits, seq_len, images = get_data(golden_data_point, config)
      max_logging.log("maxtext forward pass")
      full_train_logits = model.apply(
          state.params,
          ids,
          decoder_positions,
          decoder_segment_ids,
          encoder_images=images,
          enable_dropout=False,
          rngs={"aqt": init_rng},
      )

      full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits, tiled=True)
      # if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size]
      if full_train_logits.ndim == 4:
        full_train_logits = jnp.reshape(full_train_logits, (-1, config.max_target_length, config.vocab_size))
      # Slice to original sequence length, [num_hosts * batch_size, seq_len, vocab_size]
      full_train_logits = full_train_logits[:, :seq_len, :]

      token_size = int(test_args.token_size) if test_args.token_size else seq_len
      if full_train_logits.shape[-1] != golden_logits.shape[-1]:
        max_logging.log(
            f"Vocab size mismatch: train logits vocab size {full_train_logits.shape[-1]}, "
            f"golden logits vocab size {golden_logits.shape[-1]}. "
            "Comparing up to the smaller vocab size."
        )
      min_vocab_size = min(full_train_logits.shape[-1], golden_logits.shape[-1])
      # shape [seq_len, vocab_size]
      train_logits_slice = full_train_logits[0, :token_size, :min_vocab_size]
      golden_logits_slice = golden_logits[:token_size, :min_vocab_size]
      max_logging.log("\n[logits: token 2]")
      max_logging.log(f"{golden_logits_slice[2]=}")
      max_logging.log(f"{train_logits_slice[2]=}")

      # Calculate absolute and relative differences for detailed reporting
      abs_diff = jnp.abs(train_logits_slice - golden_logits_slice)

      # To avoid division by zero, add a small epsilon where golden_logits_slice is zero
      safe_golden_logits = jnp.where(golden_logits_slice == 0, 1e-8, golden_logits_slice)
      rel_diff = abs_diff / jnp.abs(safe_golden_logits)

      max_abs_diff_idx = jnp.unravel_index(jnp.argmax(abs_diff), abs_diff.shape)
      max_rel_diff_idx = jnp.unravel_index(jnp.argmax(rel_diff), rel_diff.shape)

      max_abs_diff_val = abs_diff[max_abs_diff_idx]
      max_rel_diff_val = rel_diff[max_rel_diff_idx]
      msg = (
          "\n[numerical difference]\n"
          f"Max absolute difference: {max_abs_diff_val:.6f} at index {max_abs_diff_idx}\n"
          f"  (Train: {train_logits_slice[max_abs_diff_idx]:.6f}, Golden: {golden_logits_slice[max_abs_diff_idx]:.6f})\n"
          f"Max relative difference: {max_rel_diff_val:.6f} at index {max_rel_diff_idx}\n"
          f"  (Train: {train_logits_slice[max_rel_diff_idx]:.6f}, Golden: {golden_logits_slice[max_rel_diff_idx]:.6f})"
      )
      max_logging.log(msg)

      # Clipping the probabilities (optional) bounds the per-token KL contribution
      # of near-zero model probabilities.
      if test_args.clip_logits_epsilon is not None:
        model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon)
        golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon)
      else:
        model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1)
        golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)

      max_logging.log("\n[probability: token 1]")
      max_logging.log(f"{golden_probabilities[1]=}")
      max_logging.log(f"{model_probabilities[1]=}")

      kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
      max_kl_div_val = jax.numpy.max(kl_div)
      max_kl_div_idx = jax.numpy.argmax(kl_div)
      max_logging.log(
          f"\n[KL divergence]\n"
          f"KL divergence = {kl_div}, max KL divergence = {max_kl_div_val} at index {max_kl_div_idx}, "
          f"the corresponding token id is {ids[0, max_kl_div_idx]}"
      )

      if jax.process_index() == 0 and test_args.output_logits_path:
        data_to_save = {
            "prompt": golden_data[golden_data_index]["prompt"],
            "tokens": ids[0, :seq_len].tolist(),
            "logits": full_train_logits[0].tolist(),
        }
        all_data_to_save.append(data_to_save)

      if test_args.atol is not None:
        max_logging.log("\n[test criteria]")
        # Bug fix: the logged values were previously swapped (atol showed rtol
        # and vice versa).
        max_logging.log(
            f"Checking Numerical Differences between train logits and golden logits against "
            f"atol={test_args.atol} rtol={test_args.rtol}."
        )
        rtol_val = float(test_args.rtol)
        atol_val = float(test_args.atol)
        assert jax.numpy.allclose(
            train_logits_slice, golden_logits_slice, rtol=rtol_val, atol=atol_val, equal_nan=False
        ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."

      if test_args.max_kl_div is not None:
        max_logging.log(
            f"Checking KL Divergence between train distribution and golden distribution against "
            f"threshold {test_args.max_kl_div}."
        )
        assert jax.numpy.all(
            kl_div < test_args.max_kl_div,
        ), (
            f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. "
            f"Max divergence: {jax.numpy.max(kl_div)}"
        )

  else:
    # Compare the MaxText model with a HF model on-the-fly.
    if test_args.hf_model_path == "":
      raise ValueError("run_hf_model requires hf_model_path")
    hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(test_args.hf_model_path)
    if "Llama-3.1" in test_args.hf_model_path:
      tokenizer.pad_token = tokenizer.eos_token

    init_rng = jax.random.PRNGKey(config.init_weights_seed)
    init_rng, rng1 = jax.random.split(init_rng)
    devices_array = maxtext_utils.create_device_mesh(config)
    mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
    quant = quantizations.configure_quantization(config)
    maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
    maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None)

    prompts = ["I love to", "Today is a", "What is the"]
    all_data_to_save = []
    for input_text in prompts:
      max_logging.log(f"\n--- Prompt: {input_text} ---")

      # Tokenize for HF
      inputs = tokenizer(
          input_text, return_tensors="pt", padding=True, max_length=config.max_target_length, truncation=True
      )
      actual_seq_len = inputs["input_ids"].shape[1]

      # Tokenize for MaxText
      mt_ids = jnp.asarray(inputs["input_ids"], dtype=jnp.int32)

      if mt_ids.shape[0] != config.global_batch_size_to_train_on:  # Ensure batch size matches
        mt_ids = jnp.repeat(mt_ids, config.global_batch_size_to_train_on // mt_ids.shape[0], axis=0)

      s = (config.global_batch_size_to_train_on, config.max_target_length)
      mt_decoder_segment_ids_full = jnp.zeros(s, dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR

      mt_decoder_segment_ids = mt_decoder_segment_ids_full[:, :actual_seq_len]

      # Create full decoder positions up to max_target_length
      mt_decoder_positions_full = jnp.stack(
          [jnp.arange(config.max_target_length, dtype=jnp.int32) for _ in range(config.global_batch_size_to_train_on)]
      )
      mt_decoder_positions = mt_decoder_positions_full[:, :actual_seq_len]

      # --- HF Forward Pass ---
      with torch.no_grad():
        hf_logits_torch = hf_model(**inputs).logits

      # --- MaxText Forward Pass ---
      mt_logits_jax = maxtext_model.apply(
          maxtext_state.params,
          mt_ids,
          mt_decoder_positions,
          mt_decoder_segment_ids,
          enable_dropout=False,
          rngs={"aqt": init_rng},
      )
      mt_logits_jax_sliced = mt_logits_jax[:, :actual_seq_len, :]
      mt_logits_torch = convert_jax_weight_to_torch(mt_logits_jax_sliced)

      # --- Compare logits for the last token prediction ---
      hf_last_token_logits = hf_logits_torch[:, -1, :]
      mt_last_token_logits = mt_logits_torch[:, -1, :]  # MaxText output already sliced to actual_seq_len

      tokens_maxtext = get_top_k_tokens_scores(mt_last_token_logits, tokenizer, k=10, description="MaxText model")
      tokens_hf = get_top_k_tokens_scores(hf_last_token_logits, tokenizer, k=10, description="HF model")
      compare_top_tokens(converted_tokens=tokens_maxtext, golden_tokens=tokens_hf)

      # --- Compare all logits in the sequence (for the first batch item) ---
      # Unsqueeze to add batch dimension for check_kl_divergence: [1, seq, vocab]
      check_kl_divergence(mt_logits_torch[0].unsqueeze(0), hf_logits_torch[0].unsqueeze(0), atol=test_args.max_kl_div)
      if jax.process_index() == 0 and test_args.output_logits_path:
        data_to_save = {
            "mt_logits": mt_logits_torch[0].tolist(),
            "hf_logits": hf_logits_torch[0].tolist(),
        }
        all_data_to_save.append(data_to_save)

  if jax.process_index() == 0 and test_args.output_logits_path:
    # Bug fix: jsonlines was previously only imported inside the .jsonl loading
    # branch, so the pickle and run_hf_model paths hit a NameError here.
    import jsonlines  # pylint: disable=import-outside-toplevel

    output_dir = os.path.dirname(test_args.output_logits_path)
    if output_dir:  # dirname is "" for a bare filename; makedirs("") would raise
      os.makedirs(output_dir, exist_ok=True)
    with jsonlines.open(test_args.output_logits_path, "a") as f:
      # write_all emits one JSON object per line, matching the documented
      # one-dict-per-line golden-data format; write() would serialize the whole
      # list onto a single line.
      f.write_all(all_data_to_save)
    max_logging.log(f"Saved logits to {test_args.output_logits_path}")

    if test_args.gcs_output_logits_path:
      bucket_name = test_args.gcs_output_logits_path.split("/")[2]
      destination_blob_name = "/".join(
          test_args.gcs_output_logits_path.split("/")[3:] + test_args.output_logits_path.split("/")[-1:]
      )
      upload_blob(bucket_name, test_args.output_logits_path, destination_blob_name)
      max_logging.log(f"Uploaded logits to {test_args.gcs_output_logits_path}")


if __name__ == "__main__":
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  parser = argparse.ArgumentParser()
  parser.add_argument("--atol", type=float, required=False, default=None)
  parser.add_argument("--rtol", type=float, required=False, default=1e-05)  # default from jnp.allclose
  parser.add_argument("--token_size", type=int, required=False)
  parser.add_argument("--max_kl_div", type=float, required=False, default=None)
  # golden_logits_path supports file format: json with suffix ".jsonl", and pickle with suffix ".pickle" or ".pkl"
  parser.add_argument("--golden_logits_path", type=str, required=False, default="")
  parser.add_argument("--hf_model_path", type=str, required=False, default="")
  parser.add_argument("--run_hf_model", type=bool, required=False, default=False)
  parser.add_argument("--output_logits_path", type=str, required=False, default="")
  parser.add_argument("--gcs_output_logits_path", type=str, required=False, default="")
  parser.add_argument("--clip_logits_epsilon", type=float, required=False, default=None)
  test_args, _ = parser.parse_known_args()

  # Remove args defined in this test file to avoid error from pyconfig
  model_args = sys.argv
  to_remove_args = [
      "--atol",
      "--rtol",
      "--token_size",
      "--max_kl_div",
      "--golden_logits_path",
      "--hf_model_path",
      "--run_hf_model",
      "--output_logits_path",
      "--gcs_output_logits_path",
      "--clip_logits_epsilon",
  ]
  for arg in to_remove_args:
    model_args = [s for s in model_args if not s.startswith(arg)]

  cfg = pyconfig.initialize(model_args)
  assert (
      test_args.atol is not None or test_args.max_kl_div is not None
  ), "At least one of --atol or --max_kl_div must be specified to define the test criteria."
  if cfg.use_multimodal:
    assert not test_args.run_hf_model, (
        "Multimodal does not support running hf model on-the-fly, please generate hf golden logits "
        "using generate_hf_golden_logits.py"
    )
  main(cfg, test_args)
